from models.dataset import ArxivDataset
from models.trainer import ModelTrainer

if __name__ == "__main__":
    # 先验证数据加载
    dataset = ArxivDataset(
        data_path="data/arxiv-metadata-oai-snapshot.json",
        tokenizer_path="configs"
    )
    print("数据类型验证：")
    print(dataset.data.dtypes)
    print("类别示例：", dataset.data['categories'].cat.categories[:5])

    # 使用相对路径（确保路径正确）
    trainer = ModelTrainer(
        config_path="configs",  # 指向包含config.json的目录
        data_path="data/arxiv-metadata-oai-snapshot.json"
    )
    trainer.train()